Interpretable by Design: Learning Predictors by Composing Interpretable Queries
There is growing concern about the typically opaque decision-making of
high-performance machine learning algorithms. Providing an explanation of the
reasoning process in domain-specific terms can be crucial for adoption in
risk-sensitive domains such as healthcare. We argue that machine learning
algorithms should be interpretable by design and that the language in which
these interpretations are expressed should be domain- and task-dependent.
Consequently, we base our model's prediction on a family of user-defined and
task-specific binary functions of the data, each having a clear interpretation
to the end-user. We then minimize the expected number of queries needed for
accurate prediction on any given input. As the solution is generally
intractable, following prior work, we choose the queries sequentially based on
information gain. However, in contrast to previous work, we need not assume the
queries are conditionally independent. Instead, we leverage a stochastic
generative model (VAE) and an MCMC algorithm (Unadjusted Langevin) to select
the most informative query about the input based on previous query-answers.
This enables the online determination of a query chain of whatever depth is
required to resolve prediction ambiguities. Finally, experiments on vision and
NLP tasks demonstrate the efficacy of our approach and its superiority over
post-hoc explanations.
Comment: 29 pages, 14 figures. Accepted as a Regular Paper in Transactions on
Pattern Analysis and Machine Intelligence.
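As a concrete illustration, the greedy information-gain selection can be sketched in a toy discrete setting where the joint distribution over labels and query-answers is known exactly and is represented by a small table of weighted atoms. This is only a sketch: the paper instead estimates these quantities with a VAE and unadjusted Langevin MCMC, and all names below are hypothetical.

```python
import numpy as np

def entropy(p):
    """Shannon entropy (bits) of a probability vector."""
    p = p[p > 0]
    return float(-(p * np.log2(p)).sum())

def information_gain(w, y, a, k):
    """Mutual information I(Y; Q_k) under atom weights w.

    w: weights of the atoms of the joint distribution (need not sum to 1),
    y: integer label of each atom, a: (n_atoms, K) binary query answers.
    """
    w = w / w.sum()
    h_y = entropy(np.bincount(y, weights=w))
    h_cond = 0.0
    for v in (0, 1):
        m = a[:, k] == v
        pv = w[m].sum()
        if pv > 0:
            h_cond += pv * entropy(np.bincount(y[m], weights=w[m]) / pv)
    return h_y - h_cond

def information_pursuit(y, a, w, true_answers, n_steps):
    """Greedily ask the most informative query, then condition on its answer."""
    mask = np.ones(len(w), dtype=bool)
    asked = []
    for _ in range(n_steps):
        gains = [-1.0 if k in asked else
                 information_gain(w[mask], y[mask], a[mask], k)
                 for k in range(a.shape[1])]
        k = int(np.argmax(gains))
        asked.append(k)
        mask &= a[:, k] == true_answers[k]  # fold the answer into the history
    return asked
```

On a toy joint where query 0 reveals the label, query 1 is noise, and query 2 is constant, the procedure asks query 0 first, as expected.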
Variational Information Pursuit for Interpretable Predictions
There is a growing interest in the machine learning community in developing
predictive algorithms that are "interpretable by design". Towards this end,
recent work proposes to make interpretable decisions by sequentially asking
interpretable queries about data until a prediction can be made with high
confidence based on the answers obtained (the history). To promote short
query-answer chains, a greedy procedure called Information Pursuit (IP) is
used, which adaptively chooses queries in order of information gain. Generative
models are employed to learn the distribution of query-answers and labels,
which is in turn used to estimate the most informative query. However, learning
and inference with a full generative model of the data is often intractable for
complex tasks. In this work, we propose Variational Information Pursuit (V-IP),
a variational characterization of IP which bypasses the need for learning
generative models. V-IP is based on finding a query selection strategy and a
classifier that jointly minimize the expected cross-entropy between true and predicted
labels. We then demonstrate that the IP strategy is the optimal solution to
this problem. Therefore, instead of learning generative models, we can use our
optimal strategy to directly pick the most informative query given any history.
We then develop a practical algorithm by defining a finite-dimensional
parameterization of our strategy and classifier using deep networks and train
them end-to-end using our objective. Empirically, V-IP is 10-100x faster than
IP on different vision and NLP tasks with competitive performance. Moreover,
V-IP finds much shorter query chains than reinforcement learning, which is
typically used in sequential decision-making problems. Finally, we
demonstrate the utility of V-IP on challenging tasks like medical diagnosis
where the performance is far superior to the generative modelling approach.
Comment: Code is available at
https://github.com/ryanchankh/VariationalInformationPursuit
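The sequential decision loop that the learned strategy and classifier realize at test time can be sketched as follows. This is a structural sketch only: `querier` and `classifier` stand in for the two deep networks trained end-to-end in the paper, and are modeled here as plain callables on the masked answer vector and observed-query mask.

```python
import numpy as np

def vip_inference(x, query_fns, querier, classifier, threshold=0.9, max_steps=10):
    """V-IP-style sequential inference loop (structural sketch).

    query_fns:  K interpretable binary queries, each a callable on the input x.
    querier:    scores the K queries given the (masked answers, observed mask).
    classifier: posterior over labels given the same masked history.
    """
    K = len(query_fns)
    answers = np.zeros(K)
    observed = np.zeros(K, dtype=bool)
    chain = []
    probs = classifier(answers * observed, observed)
    for _ in range(max_steps):
        scores = querier(answers * observed, observed)
        scores[observed] = -np.inf            # never re-ask a query
        k = int(np.argmax(scores))
        answers[k] = query_fns[k](x)          # obtain the query's answer
        observed[k] = True
        chain.append(k)
        probs = classifier(answers * observed, observed)
        if probs.max() >= threshold:          # confident enough: stop the chain
            break
    return chain, probs
```

With toy stand-ins where the first query is decisive, the loop terminates after a single query, illustrating how short chains arise when answers are informative.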
Unsupervised Manifold Linearizing and Clustering
We consider the problem of simultaneously clustering and learning a linear
representation of data lying close to a union of low-dimensional manifolds, a
fundamental task in machine learning and computer vision. When the manifolds
are assumed to be linear subspaces, this reduces to the classical problem of
subspace clustering, which has been studied extensively over the past two
decades. Unfortunately, many real-world datasets such as natural images cannot
be well approximated by linear subspaces. On the other hand, numerous works
have attempted to learn an appropriate transformation of the data, such that
data is mapped from a union of general non-linear manifolds to a union of
linear subspaces (with points from the same manifold being mapped to the same
subspace). However, many existing works have limitations such as assuming
knowledge of the membership of samples to clusters, requiring high sampling
density, or being shown theoretically to learn trivial representations. In this
paper, we propose to optimize the Maximal Coding Rate Reduction metric with
respect to both the data representation and a novel doubly stochastic cluster
membership, inspired by state-of-the-art subspace clustering results. We give a
parameterization of such a representation and membership, allowing efficient
mini-batching and one-shot initialization. Experiments on CIFAR-10, -20, -100,
and TinyImageNet-200 datasets show that the proposed method is much more
accurate and scalable than state-of-the-art deep clustering methods, and
further learns a latent linear representation of the data.
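A minimal numpy sketch of the Maximal Coding Rate Reduction objective with soft cluster memberships may help fix ideas. It uses the standard rate-distortion form of the coding rate; the paper optimizes this objective end-to-end over a learned representation and a doubly stochastic membership matrix, whereas here both are given explicitly as toy inputs.

```python
import numpy as np

def coding_rate(Z, eps=0.5):
    """R(Z): rate needed to code the columns of Z up to distortion eps."""
    d, n = Z.shape
    return 0.5 * np.linalg.slogdet(np.eye(d) + d / (n * eps**2) * Z @ Z.T)[1]

def mcr2(Z, Pi, eps=0.5):
    """Coding rate reduction with soft cluster memberships Pi of shape
    (n_samples, n_clusters): rate of the whole set minus the weighted
    rates of the (softly assigned) clusters."""
    d, n = Z.shape
    expand = coding_rate(Z, eps)
    compress = 0.0
    for j in range(Pi.shape[1]):
        w = Pi[:, j]                    # soft membership in cluster j
        nj = w.sum()
        if nj > 1e-8:
            M = np.eye(d) + d / (nj * eps**2) * (Z * w) @ Z.T
            compress += (nj / (2 * n)) * np.linalg.slogdet(M)[1]
    return expand - compress
```

For points lying on two orthogonal 1-D subspaces, the correct two-cluster membership yields a strictly larger rate reduction than lumping everything into one cluster (which gives zero), matching the intuition that the objective rewards diverse-between, compact-within representations.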
White-Box Transformers via Sparse Rate Reduction
In this paper, we contend that the objective of representation learning is to
compress and transform the distribution of the data, say sets of tokens,
towards a mixture of low-dimensional Gaussian distributions supported on
incoherent subspaces. The quality of the final representation can be measured
by a unified objective function called sparse rate reduction. From this
perspective, popular deep networks such as transformers can be naturally viewed
as realizing iterative schemes to optimize this objective incrementally.
Particularly, we show that the standard transformer block can be derived from
alternating optimization on complementary parts of this objective: the
multi-head self-attention operator can be viewed as a gradient descent step to
compress the token sets by minimizing their lossy coding rate, and the
subsequent multi-layer perceptron can be viewed as attempting to sparsify the
representation of the tokens. This leads to a family of white-box
transformer-like deep network architectures which are mathematically fully
interpretable. Despite their simplicity, experiments show that these networks
indeed learn to optimize the designed objective: they compress and sparsify
representations of large-scale real-world vision datasets such as ImageNet, and
achieve performance very close to thoroughly engineered transformers such as
ViT. Code is at https://github.com/Ma-Lab-Berkeley/CRATE.
Comment: 33 pages, 11 figures.
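The alternating compression/sparsification scheme can be sketched as one simplified layer. This is a loose numpy illustration of the two steps, not the paper's exact operators: the subspace bases (playing the role of attention heads), the sparsifying dictionary, and the step sizes are all hypothetical stand-ins.

```python
import numpy as np

def softmax(Z, axis=-1):
    Z = Z - Z.max(axis=axis, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=axis, keepdims=True)

def crate_block(Z, U_heads, D, step=0.1, lam=0.1):
    """One white-box transformer-like layer (simplified sketch).

    Z:       (d, n) matrix of token representations.
    U_heads: list of (d, p) incoherent subspace bases (the "heads").
    D:       (d, d) sparsifying dictionary for the ISTA-like step.
    """
    # Compression step (attention-like): move each token toward its
    # similarity-weighted aggregation within each subspace -- a gradient-
    # descent-like step that reduces the token set's lossy coding rate.
    update = np.zeros_like(Z)
    for U in U_heads:
        V = U.T @ Z                      # project tokens into the subspace
        A = softmax(V.T @ V, axis=-1)    # similarity-based attention weights
        update += U @ (V @ A)            # aggregate and lift back to R^d
    Z = Z + step * (update - Z)
    # Sparsification step (MLP-like): one soft-thresholding (ISTA-style)
    # iteration pushing the representation toward a sparse code in D.
    S = D.T @ Z
    S = np.sign(S) * np.maximum(np.abs(S) - lam, 0.0)
    return D @ S
```

Stacking such layers gives a network whose forward pass is, by construction, an incremental optimization of the sparse rate reduction objective, which is the sense in which the architecture is mathematically interpretable.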